/* ****************************************************************************
 * Copyright: 2017-2025 RAYLASE GmbH
 * This source code is the proprietary confidential property of RAYLASE GmbH.
 * Reproduction, publication, or any form of distribution to
 * any party other than the licensee is strictly prohibited.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#pragma once

#include "Socket.h"

#include <Defer.h>
#include <charconv>
#include <compare>
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <string>
#include <string_view>
#include <utility>

/// Address family tag for IPAddress.
/// The declaration order (None < IPv4 < IPv6) is significant: IPAddress::operator<=>
/// orders addresses of different families by comparing these enumerators.
enum class IPAddressFamily : int
{
	None = 0, ///< No address stored (default-constructed IPAddress).
	IPv4 = 1, ///< IPAddress::IPv4 holds the address.
	IPv6 = 2  ///< IPAddress::IPv6 holds the address.
};

/// Storage for one IPv4 address. Both members overlay the same memory, so the
/// address can be read either as raw bytes or as the socket-API in_addr.
union IPAddressIPv4
{
	unsigned char AddressBytes[4]; ///< Bytes as laid out in memory, e.g. {127, 0, 0, 1} for loopback.
	in_addr Address;               ///< Same address in the form used by inet_pton/inet_ntop/sockaddr_in.
};

/// Storage for one IPv6 address together with its scope (zone) identifier.
struct IPAddressIPv6
{
	// Byte view and socket-API view of the same 16-byte address.
	union
	{
		unsigned char AddressBytes[16];
		in6_addr Address;
	};
	/// Numeric scope ID, e.g. the 5 in "fe80::1%5"; 0 when no scope was given.
	std::uint32_t ScopeID;
};

class IPAddress
{
public:
	IPAddressFamily AddressFamily = IPAddressFamily::None;
	union {
		IPAddressIPv4 IPv4;
		IPAddressIPv6 IPv6;
	};

	IPAddress() = default;
	inline constexpr IPAddress(in_addr address) noexcept
	    : AddressFamily(IPAddressFamily::IPv4),
	      IPv4{.Address = address}
	{}
	inline constexpr IPAddress(in6_addr address, uint32_t scopeID) noexcept
	    : AddressFamily(IPAddressFamily::IPv6),
	      IPv6{.Address = address, .ScopeID = scopeID}
	{}
	inline constexpr IPAddress(const unsigned char (&addressIPv4)[4]) noexcept
	    : AddressFamily(IPAddressFamily::IPv4),
	      IPv4{
	          .AddressBytes = {addressIPv4[0], addressIPv4[1], addressIPv4[2], addressIPv4[3]}
    }
	{}
	inline constexpr IPAddress(const unsigned char (&addressIPv6)[16], uint32_t scopeID) noexcept
	    : AddressFamily(IPAddressFamily::IPv6),
	      IPv6{
	          .AddressBytes = {addressIPv6[0], addressIPv6[1], addressIPv6[2], addressIPv6[3], addressIPv6[4], addressIPv6[5], addressIPv6[6], addressIPv6[7],
	                           addressIPv6[8], addressIPv6[9], addressIPv6[10], addressIPv6[11], addressIPv6[12], addressIPv6[13], addressIPv6[14], addressIPv6[15]},
	          .ScopeID = scopeID
    }
	{}

	static inline IPAddress Parse(std::string_view ip)
	{
		for (std::uint32_t i = 0; i < ip.size(); i++)
		{
			if (ip[i] == '.')
			{
				in_addr address;
				inet_pton(AF_INET, std::string(ip).c_str(), &address);
				return address;
			}
			if (ip[i] == ':')
			{
				if (ip[0] == '[')
				{
					size_t pos = ip.find(']');
					if (pos == std::string::npos)
						break;
					ip = ip.substr(1, pos - 1);
				}
				size_t scopePos = ip.find('%');
				uint32_t scopeID = 0;
				if (scopePos != std::string::npos)
				{
					std::string_view scopeIdStr = ip.substr(scopePos + 1);
					const auto error = std::from_chars(scopeIdStr.data(), scopeIdStr.data() + scopeIdStr.size(), scopeID);
					if (error.ec != std::errc{})
						break;
					ip = ip.substr(0, scopePos);
				}
				in6_addr address;
				inet_pton(AF_INET6, std::string(ip).c_str(), &address);
				return {address, scopeID};
			}
		}
		throw std::runtime_error("Invalid IP address.");
	}
	static inline IPAddress ResolveHostNameOrIP(const std::string& hostName, IPAddressFamily restrictToFamily = IPAddressFamily::None)
	{
#ifdef _WIN32
		WSAStartup();
#endif
		int family = AF_UNSPEC;
		if (restrictToFamily == IPAddressFamily::IPv4)
			family = AF_INET;
		else if (restrictToFamily == IPAddressFamily::IPv6)
			family = AF_INET6;
		struct addrinfo hints = {};
		hints.ai_family = family;
		hints.ai_socktype = SOCK_STREAM;
		struct addrinfo* result;
		if (getaddrinfo(hostName.data(), nullptr, &hints, &result) == 0)
		{
			DEFER { freeaddrinfo(result); };
			for (struct addrinfo* p = result; p != nullptr; p = p->ai_next)
				if (p->ai_family == AF_INET && restrictToFamily != IPAddressFamily::IPv6)
				{ // IPv4
					struct sockaddr_in* ipv4 = reinterpret_cast<struct sockaddr_in*>(p->ai_addr);
					return {ipv4->sin_addr};
				}
				else if (p->ai_family == AF_INET6 && restrictToFamily != IPAddressFamily::IPv4)
				{ // IPv6
					struct sockaddr_in6* ipv6 = reinterpret_cast<struct sockaddr_in6*>(p->ai_addr);
					return {ipv6->sin6_addr, ipv6->sin6_scope_id};
				}
		}
		throw std::runtime_error("Invalid IP address or unable to resolve host name with DNS.");
	}
	inline std::string ToString() const
	{
		if (AddressFamily == IPAddressFamily::IPv4)
		{
			char buf[INET_ADDRSTRLEN];
			inet_ntop(AF_INET, &IPv4.Address, buf, INET_ADDRSTRLEN);
			return std::string(buf, strnlen(buf, INET_ADDRSTRLEN));
		}
		else if (AddressFamily == IPAddressFamily::IPv6)
		{
			char buf[INET6_ADDRSTRLEN + 3];
			inet_ntop(AF_INET6, &IPv6.Address, buf, INET6_ADDRSTRLEN);
			std::string addr(buf, strnlen(buf, INET6_ADDRSTRLEN));
			if (IPv6.ScopeID != 0)
				addr += '%' + std::to_string(IPv6.ScopeID);
			return addr;
		}
		return std::string();
	}

	struct Sockaddr
	{
		union {
			struct sockaddr Any;
			struct sockaddr_in IPv4;
			struct sockaddr_in6 IPv6;
		};
		std::int32_t Len;
	};

	inline Sockaddr GetSockaddr(uint16_t port) const
	{
		if (AddressFamily == IPAddressFamily::IPv4)
			return {
			    .IPv4 = {.sin_family = AF_INET, .sin_port = htons(port), .sin_addr = IPv4.Address, .sin_zero = {}},
                  .Len = sizeof(sockaddr_in)
            };
		else if (AddressFamily == IPAddressFamily::IPv6)
			return {
			    .IPv6 = {.sin6_family = AF_INET6, .sin6_port = htons(port), .sin6_flowinfo = {}, .sin6_addr = IPv6.Address, .sin6_scope_id = IPv6.ScopeID},
			    .Len = sizeof(sockaddr_in6)
            };
		else
			throw std::runtime_error("\"GetSockaddr\" failed because there was no IPv4 or IPv6 address given.");
	}

	inline constexpr bool IsIPv6LinkLocal() const noexcept
	{
		return AddressFamily == IPAddressFamily::IPv6 && IN6_IS_ADDR_LINKLOCAL(&IPv6.Address);
	}

	inline constexpr bool operator==(const IPAddress& other) const
	{
		if (AddressFamily != other.AddressFamily)
			return false;
		if (AddressFamily == IPAddressFamily::IPv4)
			return std::memcmp(IPv4.AddressBytes, other.IPv4.AddressBytes, sizeof(IPv4.AddressBytes)) == 0;
		if (AddressFamily == IPAddressFamily::IPv6)
			return std::memcmp(IPv6.AddressBytes, other.IPv6.AddressBytes, sizeof(IPv6.AddressBytes)) == 0 && IPv6.ScopeID == other.IPv6.ScopeID;
		return true;
	}

	inline constexpr bool operator!=(const IPAddress& rhs) const noexcept { return !(*this == rhs); }

	std::strong_ordering operator<=>(const IPAddress& other) const noexcept
	{
		if (AddressFamily != other.AddressFamily)
			return AddressFamily <=> other.AddressFamily;
		if (AddressFamily == IPAddressFamily::IPv4)
			return std::memcmp(IPv4.AddressBytes, other.IPv4.AddressBytes, sizeof(IPv4.AddressBytes)) <=> 0;
		if (AddressFamily == IPAddressFamily::IPv6)
		{
			auto cmp = std::memcmp(IPv6.AddressBytes, other.IPv6.AddressBytes, sizeof(IPv6.AddressBytes));
			return cmp != 0 ? (cmp < 0 ? std::strong_ordering::less : std::strong_ordering::greater) : IPv6.ScopeID <=> other.IPv6.ScopeID;
		}
		return std::strong_ordering::equal;
	}
};

// Well-known addresses. They are declared `inline constexpr` so that every translation
// unit including this header shares a single definition. A header-scope `static` would
// give each TU its own internal-linkage copy.

/// 0.0.0.0 (IPv4 wildcard address).
inline constexpr IPAddress IPAddressIPv4Any = {
    {0, 0, 0, 0}
};
/// :: with scope ID 0 (IPv6 wildcard address).
inline constexpr IPAddress IPAddressIPv6Any = {
    {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
    0
};
/// 127.0.0.1 (IPv4 loopback).
inline constexpr IPAddress IPAddressIPv4Loopback = {
    {127, 0, 0, 1}
};
/// ::1 with scope ID 0 (IPv6 loopback).
inline constexpr IPAddress IPAddressIPv6Loopback = {
    {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1},
    0
};

namespace std
{
	/// Hashes an IPAddress from its raw address bytes, plus the scope ID for IPv6.
	/// The fields hashed are exactly those compared by IPAddress::operator==, so equal
	/// addresses hash equally. Hashing bytes directly avoids the inet_ntop formatting and
	/// heap allocation of ToString(); that allocation could also throw inside this noexcept call.
	template<> struct hash<IPAddress>
	{
		size_t operator()(const IPAddress& x) const noexcept
		{
			const std::hash<std::string_view> hashBytes{};
			if (x.AddressFamily == IPAddressFamily::IPv4)
				return hashBytes(std::string_view(reinterpret_cast<const char*>(x.IPv4.AddressBytes), sizeof(x.IPv4.AddressBytes)));
			if (x.AddressFamily == IPAddressFamily::IPv6)
			{
				// Address bytes followed by the scope ID, because operator== compares both.
				char bytes[sizeof(x.IPv6.AddressBytes) + sizeof(x.IPv6.ScopeID)];
				std::memcpy(bytes, x.IPv6.AddressBytes, sizeof(x.IPv6.AddressBytes));
				std::memcpy(bytes + sizeof(x.IPv6.AddressBytes), &x.IPv6.ScopeID, sizeof(x.IPv6.ScopeID));
				return hashBytes(std::string_view(bytes, sizeof(bytes)));
			}
			// Family None: all such values compare equal, so a constant hash is consistent.
			return 66;
		}
	};
} // namespace std